// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/dns/mock_host_resolver.h"

#include <string>
#include <vector>

#include "base/bind.h"
#include "base/location.h"
#include "base/memory/ref_counted.h"
#include "base/single_thread_task_runner.h"
#include "base/stl_util.h"
#include "base/strings/pattern.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "base/threading/platform_thread.h"
#include "base/threading/thread_task_runner_handle.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/test_completion_callback.h"
#include "net/dns/host_cache.h"

#if defined(OS_WIN)
#include "net/base/winsock_init.h"
#endif

namespace net {

namespace {

    // Capacity, in entries, of the HostCache that is created when the mock
    // resolver is constructed with caching enabled.
    constexpr unsigned kMaxCacheEntries = 100;
    // TTL, in seconds, given to cached successful resolutions. Failed
    // resolutions are stored with a zero TTL instead.
    constexpr unsigned kCacheEntryTTLSeconds = 60;

} // namespace

int ParseAddressList(const std::string& host_list,
    const std::string& canonical_name,
    AddressList* addrlist)
{
    *addrlist = AddressList();
    addrlist->set_canonical_name(canonical_name);
    for (const base::StringPiece& address : base::SplitStringPiece(
             host_list, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL)) {
        IPAddress ip_address;
        if (!ip_address.AssignFromIPLiteral(address)) {
            LOG(WARNING) << "Not a supported IP literal: " << address.as_string();
            return ERR_UNEXPECTED;
        }
        addrlist->push_back(IPEndPoint(ip_address, 0));
    }
    return OK;
}

// Book-keeping for one resolution that Resolve() queued instead of completing
// inline.
struct MockHostResolverBase::Request {
    // Copy of the parameters the caller asked to resolve.
    RequestInfo info;
    // Caller-supplied output list, written when the request is resolved.
    // Request itself never deletes it.
    AddressList* addresses;
    // Run with the result code once the request has been resolved; may be
    // null, in which case nothing is run.
    CompletionCallback callback;

    Request(const RequestInfo& req_info,
        AddressList* addr,
        const CompletionCallback& cb)
        : info(req_info)
        , addresses(addr)
        , callback(cb)
    {
    }
};

// Requests are allocated with new in Resolve() and held by raw pointer in
// |requests_|. Whatever was neither resolved nor cancelled by now is released
// here through STLDeleteValues(); their callbacks are not run.
MockHostResolverBase::~MockHostResolverBase()
{
    STLDeleteValues(&requests_);
}

int MockHostResolverBase::Resolve(const RequestInfo& info,
    RequestPriority priority,
    AddressList* addresses,
    const CompletionCallback& callback,
    RequestHandle* handle,
    const BoundNetLog& net_log)
{
    DCHECK(CalledOnValidThread());
    last_request_priority_ = priority;
    num_resolve_++;
    size_t id = next_request_id_++;
    int rv = ResolveFromIPLiteralOrCache(info, addresses);
    if (rv != ERR_DNS_CACHE_MISS) {
        return rv;
    }
    if (synchronous_mode_) {
        return ResolveProc(id, info, addresses);
    }
    // Store the request for asynchronous resolution
    Request* req = new Request(info, addresses, callback);
    requests_[id] = req;
    if (handle)
        *handle = reinterpret_cast<RequestHandle>(id);

    if (!ondemand_mode_) {
        base::ThreadTaskRunnerHandle::Get()->PostTask(
            FROM_HERE,
            base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
    }

    return ERR_IO_PENDING;
}

// Answers |info| from an IP literal or the cache only; the resolver rules are
// never consulted. Returns whatever ResolveFromIPLiteralOrCache() returns
// (ERR_DNS_CACHE_MISS when neither source has an answer).
//
// Each call bumps the from-cache counter and also consumes a request id.
// |net_log| is unused.
int MockHostResolverBase::ResolveFromCache(const RequestInfo& info,
    AddressList* addresses,
    const BoundNetLog& net_log)
{
    ++num_resolve_from_cache_;
    DCHECK(CalledOnValidThread());
    ++next_request_id_;
    return ResolveFromIPLiteralOrCache(info, addresses);
}

// Drops the pending request identified by |handle| without running its
// callback. |handle| must be a value produced by Resolve() for a request that
// is still pending; anything else hits NOTREACHED().
void MockHostResolverBase::CancelRequest(RequestHandle handle)
{
    DCHECK(CalledOnValidThread());
    // Handles are request ids in disguise (see Resolve()).
    const size_t id = reinterpret_cast<size_t>(handle);
    RequestMap::iterator found = requests_.find(id);
    if (found == requests_.end()) {
        NOTREACHED() << "CancelRequest must NOT be called after request is "
                        "complete or canceled.";
        return;
    }

    // Take ownership so the Request is freed when this function returns.
    std::unique_ptr<Request> doomed(found->second);
    requests_.erase(found);
}

// Returns the cache held in |cache_|. The constructor only creates one when
// caching was requested, so the result is null otherwise.
HostCache* MockHostResolverBase::GetHostCache()
{
    return cache_.get();
}

void MockHostResolverBase::ResolveAllPending()
{
    DCHECK(CalledOnValidThread());
    DCHECK(ondemand_mode_);
    for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) {
        base::ThreadTaskRunnerHandle::Get()->PostTask(
            FROM_HERE,
            base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first));
    }
}

// Builds a resolver in asynchronous, non-on-demand mode with zeroed counters.
// Request ids start from 1 so that a handle derived from an id is
// distinguishable from a NULL RequestHandle.
//
// |use_caching| decides whether a HostCache (bounded by kMaxCacheEntries) is
// created; without it |cache_| stays unset and no lookups are cached.
MockHostResolverBase::MockHostResolverBase(bool use_caching)
    : last_request_priority_(DEFAULT_PRIORITY)
    , synchronous_mode_(false)
    , ondemand_mode_(false)
    , next_request_id_(1)
    , num_resolve_(0)
    , num_resolve_from_cache_(0)
{
    // Default rule chain: an empty rule layer in front of a catch-all layer
    // (see CreateCatchAllHostResolverProc()).
    rules_ = CreateCatchAllHostResolverProc();

    if (use_caching) {
        cache_.reset(new HostCache(kMaxCacheEntries));
    }
}

int MockHostResolverBase::ResolveFromIPLiteralOrCache(const RequestInfo& info,
    AddressList* addresses)
{
    IPAddress ip_address;
    if (ip_address.AssignFromIPLiteral(info.hostname())) {
        // This matches the behavior HostResolverImpl.
        if (info.address_family() != ADDRESS_FAMILY_UNSPECIFIED && info.address_family() != GetAddressFamily(ip_address)) {
            return ERR_NAME_NOT_RESOLVED;
        }

        *addresses = AddressList::CreateFromIPAddress(ip_address, info.port());
        if (info.host_resolver_flags() & HOST_RESOLVER_CANONNAME)
            addresses->SetDefaultCanonicalName();
        return OK;
    }
    int rv = ERR_DNS_CACHE_MISS;
    if (cache_.get() && info.allow_cached_response()) {
        HostCache::Key key(info.hostname(),
            info.address_family(),
            info.host_resolver_flags());
        const HostCache::Entry* entry = cache_->Lookup(key, base::TimeTicks::Now());
        if (entry) {
            rv = entry->error();
            if (rv == OK)
                *addresses = AddressList::CopyWithPort(entry->addresses(), info.port());
        }
    }
    return rv;
}

int MockHostResolverBase::ResolveProc(size_t id,
    const RequestInfo& info,
    AddressList* addresses)
{
    AddressList addr;
    int rv = rules_->Resolve(info.hostname(),
        info.address_family(),
        info.host_resolver_flags(),
        &addr,
        NULL);
    if (cache_.get()) {
        HostCache::Key key(info.hostname(),
            info.address_family(),
            info.host_resolver_flags());
        // Storing a failure with TTL 0 so that it overwrites previous value.
        base::TimeDelta ttl;
        if (rv == OK)
            ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds);
        cache_->Set(key, HostCache::Entry(rv, addr), base::TimeTicks::Now(), ttl);
    }
    if (rv == OK)
        *addresses = AddressList::CopyWithPort(addr, info.port());
    return rv;
}

void MockHostResolverBase::ResolveNow(size_t id)
{
    RequestMap::iterator it = requests_.find(id);
    if (it == requests_.end())
        return; // was canceled

    std::unique_ptr<Request> req(it->second);
    requests_.erase(it);
    int rv = ResolveProc(id, req->info, req->addresses);
    if (!req->callback.is_null())
        req->callback.Run(rv);
}

//-----------------------------------------------------------------------------

// One entry of the rule list. RuleBasedHostResolverProc::Resolve() walks the
// rules in insertion order and applies the first one that matches the request.
struct RuleBasedHostResolverProc::Rule {
    // How a matching request is answered.
    enum ResolverType {
        // Fail with ERR_NAME_NOT_RESOLVED.
        kResolverTypeFail,
        // Forward the (possibly replaced) hostname to SystemHostResolverCall().
        kResolverTypeSystem,
        // Parse |replacement| as a comma-separated list of IP literals.
        kResolverTypeIPLiteral,
    };

    ResolverType resolver_type;
    // Pattern the requested hostname is tested against with base::MatchPattern().
    std::string host_pattern;
    // Family the rule applies to; ADDRESS_FAMILY_UNSPECIFIED matches any family.
    AddressFamily address_family;
    // The rule matches when every flag of the request (ignoring
    // HOST_RESOLVER_SYSTEM_ONLY) is also set here.
    HostResolverFlags host_resolver_flags;
    // Hostname or IP-literal list used in place of the requested host. When
    // empty, the requested host itself is used.
    std::string replacement;
    // Canonical name given to the address list of an IP-literal rule.
    std::string canonical_name;
    // Artificial delay applied before answering; 0 means no delay.
    int latency_ms; // In milliseconds.

    Rule(ResolverType resolver_type,
        const std::string& host_pattern,
        AddressFamily address_family,
        HostResolverFlags host_resolver_flags,
        const std::string& replacement,
        const std::string& canonical_name,
        int latency_ms)
        : resolver_type(resolver_type)
        , host_pattern(host_pattern)
        , address_family(address_family)
        , host_resolver_flags(host_resolver_flags)
        , replacement(replacement)
        , canonical_name(canonical_name)
        , latency_ms(latency_ms)
    {
    }
};

// Creates a proc with an empty rule list. |previous| is handed to the
// HostResolverProc base; Resolve() defers to it through ResolveUsingPrevious()
// when no rule matches.
RuleBasedHostResolverProc::RuleBasedHostResolverProc(HostResolverProc* previous)
    : HostResolverProc(previous)
{
}

// Convenience form of AddRuleForAddressFamily() that applies to every address
// family: hosts matching |host_pattern| are resolved as |replacement|.
void RuleBasedHostResolverProc::AddRule(const std::string& host_pattern,
    const std::string& replacement)
{
    AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
        replacement);
}

void RuleBasedHostResolverProc::AddRuleForAddressFamily(
    const std::string& host_pattern,
    AddressFamily address_family,
    const std::string& replacement)
{
    DCHECK(!replacement.empty());
    HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    Rule rule(Rule::kResolverTypeSystem,
        host_pattern,
        address_family,
        flags,
        replacement,
        std::string(),
        0);
    AddRuleInternal(rule);
}

void RuleBasedHostResolverProc::AddIPLiteralRule(
    const std::string& host_pattern,
    const std::string& ip_literal,
    const std::string& canonical_name)
{
    // Literals are always resolved to themselves by HostResolverImpl,
    // consequently we do not support remapping them.
    IPAddress ip_address;
    DCHECK(!ip_address.AssignFromIPLiteral(host_pattern));
    HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    if (!canonical_name.empty())
        flags |= HOST_RESOLVER_CANONNAME;
    Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
        ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, canonical_name,
        0);
    AddRuleInternal(rule);
}

void RuleBasedHostResolverProc::AddRuleWithLatency(
    const std::string& host_pattern,
    const std::string& replacement,
    int latency_ms)
{
    DCHECK(!replacement.empty());
    HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    Rule rule(Rule::kResolverTypeSystem,
        host_pattern,
        ADDRESS_FAMILY_UNSPECIFIED,
        flags,
        replacement,
        std::string(),
        latency_ms);
    AddRuleInternal(rule);
}

void RuleBasedHostResolverProc::AllowDirectLookup(
    const std::string& host_pattern)
{
    HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    Rule rule(Rule::kResolverTypeSystem,
        host_pattern,
        ADDRESS_FAMILY_UNSPECIFIED,
        flags,
        std::string(),
        std::string(),
        0);
    AddRuleInternal(rule);
}

void RuleBasedHostResolverProc::AddSimulatedFailure(
    const std::string& host_pattern)
{
    HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY | HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
    Rule rule(Rule::kResolverTypeFail,
        host_pattern,
        ADDRESS_FAMILY_UNSPECIFIED,
        flags,
        std::string(),
        std::string(),
        0);
    AddRuleInternal(rule);
}

// Removes every rule added so far. Takes |rule_lock_|, the same lock that
// Resolve() and AddRuleInternal() hold while touching |rules_|.
void RuleBasedHostResolverProc::ClearRules()
{
    base::AutoLock lock(rule_lock_);
    rules_.clear();
}

int RuleBasedHostResolverProc::Resolve(const std::string& host,
    AddressFamily address_family,
    HostResolverFlags host_resolver_flags,
    AddressList* addrlist,
    int* os_error)
{
    base::AutoLock lock(rule_lock_);
    RuleList::iterator r;
    for (r = rules_.begin(); r != rules_.end(); ++r) {
        bool matches_address_family = r->address_family == ADDRESS_FAMILY_UNSPECIFIED || r->address_family == address_family;
        // Ignore HOST_RESOLVER_SYSTEM_ONLY, since it should have no impact on
        // whether a rule matches.
        HostResolverFlags flags = host_resolver_flags & ~HOST_RESOLVER_SYSTEM_ONLY;
        // Flags match if all of the bitflags in host_resolver_flags are enabled
        // in the rule's host_resolver_flags. However, the rule may have additional
        // flags specified, in which case the flags should still be considered a
        // match.
        bool matches_flags = (r->host_resolver_flags & flags) == flags;
        if (matches_flags && matches_address_family && base::MatchPattern(host, r->host_pattern)) {
            if (r->latency_ms != 0) {
                base::PlatformThread::Sleep(
                    base::TimeDelta::FromMilliseconds(r->latency_ms));
            }

            // Remap to a new host.
            const std::string& effective_host = r->replacement.empty() ? host : r->replacement;

            // Apply the resolving function to the remapped hostname.
            switch (r->resolver_type) {
            case Rule::kResolverTypeFail:
                return ERR_NAME_NOT_RESOLVED;
            case Rule::kResolverTypeSystem:
#if defined(OS_WIN)
                EnsureWinsockInit();
#endif
                return SystemHostResolverCall(effective_host,
                    address_family,
                    host_resolver_flags,
                    addrlist, os_error);
            case Rule::kResolverTypeIPLiteral:
                return ParseAddressList(effective_host,
                    r->canonical_name,
                    addrlist);
            default:
                NOTREACHED();
                return ERR_UNEXPECTED;
            }
        }
    }
    return ResolveUsingPrevious(host, address_family,
        host_resolver_flags, addrlist, os_error);
}

// Nothing to release explicitly; members clean up after themselves.
RuleBasedHostResolverProc::~RuleBasedHostResolverProc()
{
}

// Appends a copy of |rule| to the end of the rule list while holding
// |rule_lock_|. Resolve() picks the first matching rule, so rules added
// earlier take precedence.
void RuleBasedHostResolverProc::AddRuleInternal(const Rule& rule)
{
    base::AutoLock lock(rule_lock_);
    rules_.push_back(rule);
}

// Builds a two-layer rule chain and returns its outer layer, which starts
// with no rules. Hosts that no outer rule matches fall through to an inner
// layer that maps every hostname to 127.0.0.1 with canonical name "localhost".
RuleBasedHostResolverProc* CreateCatchAllHostResolverProc()
{
    // Inner layer: the catch-all, with nothing behind it.
    RuleBasedHostResolverProc* fallback = new RuleBasedHostResolverProc(nullptr);
    fallback->AddIPLiteralRule("*", "127.0.0.1", "localhost");

    // Outer layer: the rules-based layer the user controls.
    return new RuleBasedHostResolverProc(fallback);
}

//-----------------------------------------------------------------------------

// Always reports the request as pending. None of the arguments are used:
// |callback| is not stored or run, and |addresses| and |out_req| are left
// untouched.
int HangingHostResolver::Resolve(const RequestInfo& info,
    RequestPriority priority,
    AddressList* addresses,
    const CompletionCallback& callback,
    RequestHandle* out_req,
    const BoundNetLog& net_log)
{
    return ERR_IO_PENDING;
}

// Always reports a cache miss; |addresses| is left untouched.
int HangingHostResolver::ResolveFromCache(const RequestInfo& info,
    AddressList* addresses,
    const BoundNetLog& net_log)
{
    return ERR_DNS_CACHE_MISS;
}

//-----------------------------------------------------------------------------

// Creates an instance without installing a proc; Init() is not called here.
ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() { }

// Creates an instance and immediately installs |proc| through Init().
ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc(
    HostResolverProc* proc)
{
    Init(proc);
}

// Puts back the proc that was the default before this instance installed its
// own, and CHECKs that the default being replaced is still this instance's.
ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc()
{
    HostResolverProc* const replaced = HostResolverProc::SetDefault(previous_proc_.get());
    // The lifetimes of multiple instances must be nested.
    CHECK_EQ(replaced, current_proc_.get());
}

// Installs |proc| as the default through HostResolverProc::SetDefault() and
// remembers the proc that was the default before, so the destructor can put
// it back. |proc| must not be null: it is dereferenced below.
void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc)
{
    current_proc_ = proc;
    previous_proc_ = HostResolverProc::SetDefault(current_proc_.get());
    // Hand the previous default to |proc| (presumably so lookups it does not
    // handle fall back to it; confirm against HostResolverProc::SetLastProc).
    current_proc_->SetLastProc(previous_proc_.get());
}

} // namespace net
